Add state_dict converter for DeepSeekv3 in torchtitan #1538
Conversation
@@ -16,12 +16,12 @@
from tokenizer.tiktoken import BaseTokenizer, IGNORE_INDEX
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transform import CLIPTransform
from utils import load_image
NOTE: This change is only because I ran pre-commit.
@@ -282,10 +282,12 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
        self.register_buffer(
            "expert_bias",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=True,
        )
NOTE: Explicitly specify whether the registered buffer is persistent. When it is false, we do not expect to load it from a DCP checkpoint.
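For reference, a minimal sketch (not from this PR) of what the `persistent` flag controls: persistent buffers appear in `state_dict()` and are restored from checkpoints, while non-persistent buffers are skipped by both. The `Router` class and the second buffer here are illustrative assumptions.

```python
import torch
import torch.nn as nn


class Router(nn.Module):
    def __init__(self, num_experts: int = 8):
        super().__init__()
        # Saved to and restored from checkpoints (e.g. via DCP).
        self.register_buffer(
            "expert_bias",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=True,
        )
        # Runtime-only statistics; never written to or read from a checkpoint.
        self.register_buffer(
            "tokens_per_expert",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=False,
        )


router = Router()
print(sorted(router.state_dict().keys()))  # ['expert_bias'] only
```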
Maybe this needs a rebase onto #1526 after it lands.
If I understand correctly, the conversion
- can be used to offline-convert an HF checkpoint from fp8 to fp32 using CPU plain tensors (sketched below).
- can't be used to convert an HF checkpoint on the fly using GPU DTensor, because sharding and quantized blocks may not be aligned well.
- can't be used for weight sync to generate a bf16 state dict, because fake quantization to fp8 is applied.
I think it's OK to land this PR to unblock the first use case, but it would be better to document these limitations clearly somewhere.
I also had some inline comments.
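A hedged sketch of the first use case above: offline dequantization of a blockwise-quantized fp8 HF weight into fp32 on CPU plain tensors. The `BLOCK_SIZE`, the 128x128 block layout, and the `dequantize_block_fp8` helper are assumptions for illustration, not the PR's actual converter.

```python
import torch

BLOCK_SIZE = 128  # assumed per-block quantization granularity


def dequantize_block_fp8(weight: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Upcast an fp8 weight to fp32, rescaling each (128, 128) block.

    weight:    (M, N) tensor in torch.float8_e4m3fn
    scale_inv: (ceil(M/128), ceil(N/128)) per-block inverse scales
    """
    m, n = weight.shape
    out = weight.to(torch.float32)
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            rows = slice(i * BLOCK_SIZE, min((i + 1) * BLOCK_SIZE, m))
            cols = slice(j * BLOCK_SIZE, min((j + 1) * BLOCK_SIZE, n))
            out[rows, cols] *= scale_inv[i, j]
    return out
```

This only works when the whole weight is materialized as a plain tensor; with a sharded DTensor, shard boundaries need not align with the 128x128 quantization blocks, which is why the on-the-fly GPU path above doesn't apply.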
LGTM to unblock
Support loading DeepSeek HF weights into the DeepSeek-V3 model:
Numerical verification (using the offline conversion script):
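To make the converter's job concrete, here is a hedged sketch of the key-renaming step a state_dict converter performs between HF checkpoint names and torchtitan module names. The mapping entries and helper names are illustrative assumptions, not the PR's actual rules.

```python
import re

# Hypothetical (HF pattern -> torchtitan pattern) rules for illustration.
_KEY_RULES = [
    (r"^model\.embed_tokens\.weight$", "tok_embeddings.weight"),
    (r"^model\.layers\.(\d+)\.input_layernorm\.weight$", r"layers.\1.attention_norm.weight"),
    (r"^model\.norm\.weight$", "norm.weight"),
]


def convert_hf_key(hf_key: str) -> str:
    for pattern, replacement in _KEY_RULES:
        if re.match(pattern, hf_key):
            return re.sub(pattern, replacement, hf_key)
    return hf_key  # pass through keys with no matching rule


def convert_state_dict(hf_sd: dict) -> dict:
    # Rename every key; tensor values are carried over unchanged
    # (dequantization, if needed, happens in a separate pass).
    return {convert_hf_key(k): v for k, v in hf_sd.items()}
```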